import os
import numpy as np
import shutil

# Split each class folder under <root_path>/train into train/val by MOVING a
# random ``val_ratio`` share of its .jpg images into <val_path>/<class>.
# Images that are not moved stay in the train folder. Running the script a
# second time moves a further share out of train, so run it once per dataset.
root_path = "/home/h/搅拌站/盐城/data_1_mix_0705_lower"
val_path = "/home/h/搅拌站/盐城/data_1_mix_0705_lower/val"
train_ratio = 0.8
val_ratio = 0.2
# Seed for the random split: None draws fresh OS entropy (a different split
# every run); set an int to make the split reproducible.
seed = None


def _list_images(class_dir, extension=".jpg"):
    """Return the sorted full paths of files in *class_dir* ending in *extension*."""
    return sorted(
        os.path.join(class_dir, name)
        for name in os.listdir(class_dir)
        if name.endswith(extension)
    )


def _split_sizes(n_images, train_ratio, val_ratio):
    """Return ``(train_size, val_size, remainder)`` for *n_images* images.

    Both sizes are floored, so ``remainder`` holds the images lost to rounding
    (they are not moved and therefore stay in the train folder).

    Raises:
        ValueError: if a ratio is outside [0, 1] or the floored sizes add up
            to more than ``n_images``.
    """
    for label, ratio in (("train_ratio", train_ratio), ("val_ratio", val_ratio)):
        if not 0.0 <= ratio <= 1.0:
            raise ValueError(f"{label} must be in [0, 1], got {ratio}")
    train_size = int(np.floor(n_images * train_ratio))
    val_size = int(np.floor(n_images * val_ratio))
    remainder = n_images - train_size - val_size
    if remainder < 0:
        raise ValueError(
            f"train_ratio + val_ratio must not exceed 1, got {train_ratio} + {val_ratio}"
        )
    return train_size, val_size, remainder


def split_class_dir(class_dir, dest_dir, *, train_ratio=0.8, val_ratio=0.2,
                    rng=None, extension=".jpg"):
    """Move a random ``val_ratio`` share of the images in *class_dir* to *dest_dir*.

    Args:
        class_dir: folder holding one class's training images.
        dest_dir: validation folder for that class; created if missing.
        train_ratio: share counted as train. It only affects the reported
            sizes; every image that is not moved stays in *class_dir*.
        val_ratio: share of images (floored) moved to *dest_dir*.
        rng: object with a NumPy-style ``choice(a, size, replace)`` method,
            e.g. ``np.random.default_rng(0)``; defaults to the global
            ``np.random`` state.
        extension: only file names ending with this suffix are considered.

    Returns:
        ``(train_size, val_size, remainder)`` computed from the ratios.

    Raises:
        ValueError: for invalid ratios (nothing is moved in that case).
        shutil.Error: if a file of the same name already exists in *dest_dir*.
    """
    images = _list_images(class_dir, extension)
    train_size, val_size, remainder = _split_sizes(len(images), train_ratio, val_ratio)
    # Without this, shutil.move would rename the first image to the *path*
    # dest_dir and every later move would overwrite that single file.
    os.makedirs(dest_dir, exist_ok=True)
    if val_size:
        chooser = np.random if rng is None else rng
        for idx in np.sort(chooser.choice(len(images), val_size, replace=False)):
            shutil.move(images[idx], dest_dir)
            print("val_image:", images[idx])
    return train_size, val_size, remainder


def main():
    """Split every class sub-folder of ``<root_path>/train`` into ``val_path``."""
    train_root = os.path.join(root_path, "train")
    rng = np.random.default_rng(seed)
    for class_name in sorted(os.listdir(train_root)):
        class_dir = os.path.join(train_root, class_name)
        # os.listdir also returns plain files (label lists, .DS_Store, ...).
        if not os.path.isdir(class_dir):
            continue
        print(class_dir)
        train_size, val_size, test_size = split_class_dir(
            class_dir,
            os.path.join(val_path, class_name),
            train_ratio=train_ratio,
            val_ratio=val_ratio,
            rng=rng,
        )
        print(f"train_size: {train_size}, val_size: {val_size}, test_size: {test_size}")


if __name__ == "__main__":
    main()